from typing import List, Set, Tuple

# Sample data passed to sum_zero by the print call further down.
# NOTE: this module-level name shadows the built-in ``input`` function.
input = [-1, 0, 1, 2, -1, -4]


def sum_zero(input: List[int]) -> Set[Tuple[int, int, int]]:
    """Return every distinct triplet of values from *input* that sums to zero.

    The values are sorted, then each position is taken in turn as the
    smallest member of a candidate triplet while two pointers sweep the
    remaining values inwards from both ends.

    Args:
        input: The integers to search. The list itself is not modified.
            The parameter name shadows the ``input`` builtin; it is kept
            so existing keyword callers keep working.

    Returns:
        A set of ``(a, b, c)`` tuples with ``a <= b <= c`` and
        ``a + b + c == 0``. Each triplet uses three different positions of
        *input*. The set is empty when there are fewer than three values
        or when no triplet sums to zero.
    """
    ordered = sorted(input)
    triplets: Set[Tuple[int, int, int]] = set()
    last = len(ordered) - 1

    for anchor_idx, anchor in enumerate(ordered):
        # The anchor is the smallest member of the triplet; once it is
        # positive all three members are positive and cannot sum to zero.
        if anchor > 0:
            break
        # A repeated anchor value can only rediscover triplets that the
        # previous, wider sweep with the same value already recorded.
        if anchor_idx > 0 and anchor == ordered[anchor_idx - 1]:
            continue

        low = anchor_idx + 1
        high = last
        while low < high:
            total = anchor + ordered[low] + ordered[high]
            if total == 0:
                triplets.add((anchor, ordered[low], ordered[high]))
                low += 1
                high -= 1
            elif total < 0:
                # Sum too small: only a larger low value can raise it.
                low += 1
            else:
                # Sum too large: only a smaller high value can lower it.
                high -= 1
    return triplets


# Demo: run sum_zero on the module-level sample list and print the resulting
# set of triplets. This is unguarded top-level code, so it also runs on import.
print(sum_zero(input))
